import tensorflow as tf

# Input batch in TensorFlow's default NHWC layout:
# [batch, height, width, in_channels].
x = tf.random.normal([4, 32, 32, 3])
print(x.shape)

# Kernel layout expected by tf.nn.conv2d:
# [filter_height, filter_width, in_channels, out_channels].
# out_channels (4 here) becomes the channel count fed to the next layer.
w = tf.random.normal([5, 5, 3, 4])
# One bias value per output channel; it broadcasts over the last axis
# of the convolution output.
b = tf.zeros([4])

# Every result below is NHWC: [batch, out_height, out_width, out_channels].
# Spatial size for a 32x32 input and 5x5 kernel:
#   'VALID': ceil((in - k + 1) / stride) -> 28 at stride 1, 14 at stride 2
#   'SAME':  ceil(in / stride)           -> 32 at stride 1
for stride, pad_mode in ((1, 'VALID'), (2, 'VALID'), (1, 'SAME')):
    out = tf.nn.conv2d(x, w, strides=stride, padding=pad_mode) + b
    print(out.shape)
